{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.02946984, -0.02137133,  0.04337265, -0.03912106], dtype=float32)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "#定义环境\n",
    "class MyWrapper(gym.Wrapper):\n",
    "\n",
    "    def __init__(self):\n",
    "        env = gym.make('CartPole-v1', render_mode='rgb_array')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self):\n",
    "        state, _ = self.env.reset()\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, terminated, truncated, info = self.env.step(action)\n",
    "        done = terminated or truncated\n",
    "        self.step_n += 1\n",
    "        if self.step_n >= 200:\n",
    "            done = True\n",
    "        return state, reward, done, info\n",
    "\n",
    "\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAF7CAYAAAD4/3BBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAApE0lEQVR4nO3df3RU9Z3/8ddMfgyEMBMDJJNIgigIRgh2AcNUa+mSEhCtrHGPWlawy5Ejm3gKsRTTpSp2j3F1z/qjVfhju+KeI0XtV7RSwSJIWGtATEn5JSmwtEHJJAjfzEBsfs18vn+4zPmOhB8TQuYz5Pk4556TuZ/P3Hnfz0mYF/d+7r0OY4wRAACARZzxLgAAAODrCCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDpxDSgvvviirrrqKg0YMEBFRUX6+OOP41kOAACwRNwCymuvvaaKigo99thj+sMf/qAJEyaopKREzc3N8SoJAABYwhGvhwUWFRVp8uTJ+sUvfiFJCofDysvL00MPPaRHHnkkHiUBAABLJMfjQzs6OlRbW6vKysrIOqfTqeLiYtXU1JzRv729Xe3t7ZHX4XBYJ06c0JAhQ+RwOPqkZgAAcHGMMTp58qRyc3PldJ77JE5cAsoXX3yhUCik7OzsqPXZ2dnav3//Gf2rqqq0fPnyvioPAABcQkeOHNHw4cPP2ScuASVWlZWVqqioiLwOBALKz8/XkSNH5Ha741gZAAC4UMFgUHl5eRo8ePB5+8YloAwdOlRJSUlqamqKWt/U1CSv13tGf5fLJZfLdcZ6t9tNQAEAIMFcyPSMuFzFk5qaqokTJ2rTpk2RdeFwWJs2bZLP54tHSQAAwCJxO8VTUVGhefPmadKkSbrxxhv13HPPqbW1VT/4wQ/iVRIAALBE3ALK3XffrWPHjunRRx+V3+/XDTfcoA0bNpwxcRYAAPQ/cbsPysUIBoPyeDwKBALMQQEAIEHE8v3Ns3gAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKzT6wHl8ccfl8PhiFrGjh0baW9ra1NZWZmGDBmi9PR0lZaWqqmpqbfLAAAACeySHEG5/vrr1djYGFk+/PDDSNvixYv1zjvv6I033lB1dbWOHj2qO++881KUAQAAElTyJdlocrK8Xu8Z6wOBgH75y19q9erV+tu//VtJ0ssvv6zrrrtO27Zt05QpUy5FOQAAIMFckiMoBw4cUG5urq6++mrNmTNHDQ0NkqTa2lp1dnaquLg40nfs2LHKz89XTU3NWbfX3t6uYDAYtQAAgMtXrweUoqIirVq1Shs2bNCKFSt0+PBhfetb39LJkyfl9/uVmpqqjIyMqPdkZ2fL7/efdZtVVVXyeDyRJS8vr7fLBgAAFun1UzwzZ86M/FxYWKiioiKNGDFCr7/+ugYOHNijbVZWVqqioiLyOhgMElIAALiMXfLLjDMyMnTttdfq4MGD8nq96ujoUEtLS1SfpqambuesnOZyueR2u6MWAABw+brkAeXUqVM6dOiQcnJyNHHiRKWkpGjTpk2R9vr6ejU0NMjn813qUgAAQILo9VM8P/rRj3T77bdrxIgROnr0qB577DElJSXp3nvvlcfj0fz581VRUaHMzEy53W499NBD8vl8XMEDAAAiej2gfPbZZ7r33nt1/PhxDRs2TDfffLO2bdumYcOGSZKeffZZOZ1OlZaWqr29XSUlJXrppZd6uwwAAJDAHMYYE+8iYhUMBuXxeBQIBJiPAgBAgojl+5tn8QAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArBNzQNm6datuv/125ebmyuFw6K233opqN8bo0UcfVU5OjgYOHKji4mIdOHAgqs+JEyc0Z84cud1uZWRkaP78+Tp16tRF7QgAALh8xBxQWltbNWHCBL344ovdtj/99NN64YUXtHLlSm3fvl2DBg1SSUmJ2traIn3mzJmjvXv3auPGjVq3bp22bt2qBQsW9HwvAADAZcVhjDE9frPDobVr12r27NmSvjp6kpubq4cfflg/+tGPJEmBQEDZ2dlatWqV7rnnHn366acqKCjQjh07NGnSJEnShg0bdOutt+qzzz5Tbm7ueT83GAzK4/EoEAjI7Xb3tHwAANCHYvn+7tU5KIcPH5bf71dxcXFkncfjUVFRkWpqaiRJNTU1ysjIiIQTSSouLpbT6dT27du73W57e7uCwWDUAgAALl+9GlD8fr8kKTs7O2p9dnZ2pM3v9ysrKyuqPTk5WZmZmZE+X1dVVSWPxxNZ8vLyerNsAABgmYS4iqeyslKBQCCyHDlyJN4lAQCAS6hXA4rX65UkNTU1Ra1vamqKtHm9XjU3N0e1d3V16cSJE5E+X+dyueR2u6MWAABw+erVgDJy5Eh5vV5t2rQpsi4YDGr79u3y+XySJJ/Pp5aWFtXW1kb6bN68WeFwWEVFRb1ZDgAASFDJsb7h1KlTOnjwYOT14cOHVVdXp8zMTOXn52vRokX6l3/5F40ePVojR47UT3/6U+Xm5kau9Lnuuus0Y8YMPfDAA1q5cqU6OztVXl6ue+6554Ku4AEAAJe/mAPKJ598ou985zuR1xUVFZKkefPmadWqVfrxj3+s1tZWLViwQC0tLbr55pu1YcMGDRgwIPKeV199VeXl5Zo2bZqcTqdKS0v1wgsv9MLuAACAy8FF3QclXrgPCgAAiSdu90EBAADoDQQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWiTmgbN26Vbfffrtyc3PlcDj01ltvRbXff//9cjgcUcuMGTOi+pw4cUJz5syR2+1WRkaG5s+fr1OnTl3UjgAAgMtHzAGltbVVEyZM0IsvvnjWPjNmzFBjY2Nk+dWvfhXVPmfOHO3du1cbN27UunXrtHXrVi1YsCD26gEAwGUpOdY3zJw5UzNnzjxnH5fLJa/X223bp59+qg0bNmjHjh2aNGmSJOnnP/+5br31Vv3bv/2bcnNzYy0JAABcZi7JHJQtW7YoKytLY8aM0cKFC3X8+PFIW01NjTIyMiLhRJKKi4vldDq1ffv2brfX3t6uYDAYtQAAgMtXrweUGTNm6L/+67+0adMm/eu//quqq6s1c+ZMhUIhSZLf71dWVlbUe5KTk5WZmSm/39/tNquqquTxeCJLXl5eb5cNAAAsEvMpnvO55557Ij+PHz9ehYWFuuaaa7RlyxZNmzatR9usrKxURUVF5HUwGCSkAABwGbvklxlfffXVGjp0qA4ePChJ8nq9am5ujurT1dWlEydOnHXeisvlktvtjloAAMDl65IHlM8++0zHjx9XTk6OJMnn86mlpUW1tbWRPps3b1Y4HFZRUdGlLgcAACSAmE/xnDp1KnI0RJIOHz6suro6ZWZmKjMzU8uXL1dpaam8Xq8OHTqkH//4xxo1apRKSkokSdddd51mzJihBx54QCtXrlRnZ6fKy8t1zz33cAUPAACQJDmMMSaWN2zZskXf+c53zlg/b948rVixQrNnz9bOnTvV0tKi3NxcTZ8+XT/72c+UnZ0d6XvixAmVl5frnXfekdPpVGlpqV544QWlp6dfUA3BYFAej0eBQIDTPQAAJIhYvr9jDig2IKAAAJB4Yvn+5lk8AADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGCdmB8WCAC9rfXYX/T5J785Zx/X4CEacfP3+6giAPFGQAEQV8YYdba2KNCw+5z9BmZe2UcVAbABp3gAxJlRONQZ7yIAWIaAAiC+jFGoqyPeVQCwDAEFQFwZSYaAAuBrCCgA4stwigfAmQgoAOLMKNxFQAEQjYACIL44ggKgGwQUAHHFHBQA3SGgAIgvjqAA6AYBBUB8GaMwR1AAfA0BBUBchTrbdPzA9nN3cjg1rODbfVMQACsQUADEnQmHztnucDiUnDqwj6oBYAMCCoCE4ExOjXcJAPoQAQVAQiCgAP0LAQVAQiCgAP0LAQVAAnAQUIB+hoACICEQUID+hYACwH4OyZniincVAPoQAQVAQuAICtC/xBRQqqqqNHnyZA0ePFhZWVmaPXu26uvro/q0tbWprKxMQ4YMUXp6ukpLS9XU1BTVp6GhQbNmzVJaWpqysrK0ZMkSdXV1XfzeALhMOeRMSol3EQD6UEwBpbq6WmVlZdq2bZs2btyozs5OTZ8+Xa2trZE+ixcv1jvvvKM33nhD1dXVOnr0qO68885IeygU0qxZs9TR0aGPPvpIr7zyilatWqVHH3209/YKQEIwxuirxwWeH0dQgP7FYb76F6JHjh07pqysLFVXV+uWW25RIBDQsGHDtHr1at11112SpP379+u6665TTU2NpkyZovXr1+u2227T0aNHlZ2dLUlauXKlli5dqmPHjik19fz/CAWDQXk8HgUCAbnd7p6WDyDOjDFqDzRr92s/PWc/R1KKJs7/hRwORx9VBuBSiOX7+6LmoAQCAUlSZmamJKm2tladnZ0qLi6O9Bk7dqzy8/NVU1MjSaqpqdH48eMj4USSSkpKFAwGtXfv3m4/p729XcFgMGoBcHkI8aBAAN3ocUAJh8NatGiRbrrpJo0bN06S5Pf7lZqaqoyMjKi+2dnZ8vv9kT7/fzg53X66rTtVVVXyeDyRJS8vr6dlA7BMuLM93iUAsFCPA0pZWZn27NmjNWvW9GY93aqsrFQgEIgsR44cueSfCaBvhLsIKADOlNyTN5WXl2vdunXaunWrhg8fHlnv9XrV0dGhlpaWqKMoTU1N8nq9kT4ff/xx1PZOX+Vzus/XuVwuuVzcAwG4HIU5xQOgGzEdQTHGqLy8XGvXrtXmzZs1cuTIqPaJEycqJSVFmzZtiqyrr69XQ0ODfD6fJMnn82n37t1qbm6O9Nm4caPcbrcKCgouZl8AJCDmoADoTkxHUMrKyrR69Wq9/fbbGjx4cGTOiMfj0cCBA+XxeDR//nxVVFQoMzNTbrdbDz30kHw+n6ZMmSJJmj59ugoKCnTffffp6aeflt/v17Jly1RWVsZREqAfYg4KgO7EFFBWrFghSZo6dWrU+pdffln333+/JOnZZ5+V0+lUaWmp2tvbVVJSopdeeinSNykpSevWrdPChQvl8/k0aNAgzZs3T0888cTF7QmAhMQpHgDduaj7oMQL90EBLg/GGDXuXK/Pd7x1zn7cBwW4PPTZfVAA4GK1n/zivH1SBvIfEaC/IaAAiKvjB7adt8/QMd/sg0oA2ISAAsB6zhQm0AP9DQEFgPWSeFAg0O8QUABYjycZA/0PAQWA9ZzJnOIB+hsCCgDrcQQF6H8IKACsxyRZoP8hoACwnjOFIyhAf0NAAWC9JOagAP0OAQWA9ZiDAvQ/BBQAcWPCoQvq50hK5jk8QD9DQAEQNybUGe8SAFiKgAIgbkJdHVLCPU8dQF8goACIm3Bne7xLAGApAgqAuAl3dYhDKAC6Q0ABEDdfBRQAOBMBBUDcEFAAnA0BBUDcEFAAnA0BBUDchLs6ZAxzUACciYACIG5CHEEBcBYEFABxc/Lz/dJ5jqCkDbtKTmdyH1UEwBYEFABx8+UXDTrfZcaDhubLkURAAfobAgoAqzmTUySewwP0OwQUAFZzJqWIeAL0PwQUAFZzJKdKRBSg3yGgALCaM4lTPEB/REABYDVnckq8SwAQBwQUAFZzJqfKwREUoN8hoACwmiOJOShAf0RAAWC1JC4zBvqlmAJKVVWVJk+erMGDBysrK0uzZ89WfX19VJ+pU6fK4XBELQ8++GBUn4aGBs2aNUtpaWnKysrSkiVL1NXVdfF7A+Cy40xiDgrQH8V0e8bq6mqVlZVp8uTJ6urq0k9+8hNNnz5d+/bt06BBgyL9HnjgAT3xxBOR12lpaZGfQ6GQZs2aJa/Xq48++kiNjY2aO3euUlJS9OSTT/bCLgFIBOFQl8x57iIrSXI4mYMC9EMxBZQNGzZEvV61apWysrJUW1urW265JbI+LS1NXq+322387ne/0759+/T+++8rOztbN9xwg372s59p6dKlevzxx5WamtqD3QCQaEyo67zP4QHQf13UHJRAICBJyszMjFr/6quvaujQoRo3bpwqKyv15ZdfRtpqamo0fvx4ZWdnR9aVlJQoGAxq79693X5Oe3u7gsFg1AIgsYVDnTIEFABn0eMncIXDYS1atEg33XSTxo0bF1n//e9/XyNGjFBubq527dqlpUuXqr6+Xm+++aYkye/3R4UTSZHXfr+/28+qqqrS8uXLe1oqAAuFuzolE453GQAs1eOAUlZWpj179ujDDz+MWr9gwYLIz+PHj1dOTo6mTZumQ4cO6ZprrunRZ1VWVqqioiLyOhgMKi8vr2eFA7CCCXVyigfAWfXoFE95ebnWrVunDz74QMOHDz9n36KiIknSwYMHJUler1dNTU1RfU6/Ptu8FZfLJbfbHbUASGzhUBeneACcVUwBxRij8vJyrV27Vps3b9bIkSPP+566ujpJUk5OjiTJ5/Np9+7dam5ujvTZuHGj3G63CgoKYikHQAIzoU7pQq7iAdAvxXSKp6ysTKtXr9bbb7+twYMHR+aMeDweDRw4UIcOHdLq1at16623asiQIdq1a5cWL16sW265RYWFhZKk6dOnq6CgQPfdd5+efvpp+f1+LVu2TGVlZXK5XL2/hwCs9NUkWeagAOheTEdQVqxYoUAgoKlTpyonJyeyvPbaa5Kk1NRUvf/++5o+fbrGjh2rhx9+WKWlpXrnnXci20hKStK6deuUlJQkn8+nf/iHf9DcuXOj7psC4PIX5jJjAOcQ0xGU850vzsvLU3V19Xm3M2LECL377ruxfDSAywyXGQM4F57FAyAuAn/+o7raTp6zT3rOtXK5h/RRRQBsQkABEBfhUMd5T/EkpwyQw8mzeID+iIACwFqOpGQ5nDyHB+iPCCgArOVISpbDwT9TQH/EXz4AazmcSRIBBeiX+MsHYC1nUrIcTv6ZAvoj/vIBWMvh5BQP0F/xlw/AWo6kZMmZFO8yAMQBAQWAtZxMkgX6Lf7yAViLUzxA/8VfPoA+d6G3uHckJUsO7oMC9EcEFAB9zoRDMuHQBfV1EFCAfomAAqDPfRVQwvEuA4DFCCgA+lwsR1AA9E8EFAB9joAC4HwIKAD6nAl3yRhO8QA4OwIKgD5nQhxBAXBuBBQAfY5TPADOh4ACoM9xFQ+A80mOdwEAEosxRqHQxR396OrskAl3nf+zwmF1dZ2/39kkJSVxHxUgQRFQAMQkHA7L4/Goo6Ojx9sovDpbS+/1aWTOFWft09kV0tLKn+i1D2b3+HP27dun0aNH9/j9AOKHgAIgZl1dXRd1ZGP08IxzhhNJ+ktTQHUHGi/qcy70lvoA7ENAARA3xkhNHSPUGrpCRg4NdAaV7fqLkh1dCoXC6goxTwXorwgoAOJmb+vNau7IV0d4oIwcSnW06fP2MZrsXq9QOKxOAgrQb3EVD4A+Z4xTe0/dpM/axqo9nC6jJElOdZg0He+8UtsC31NnyKGuLi5FBvorAgqAPtfQVqCGtgKZbv8JcqilK0u1ge+os4sjKEB/RUABEAeO/13O3h4KGU7xAP0YAQWAlULhsLou8n4rABIXAQWAlbrCYU7xAP0YAQVAn8sb8KlyXX+S1N19SozSk07o+rQPOMUD9GMxBZQVK1aosLBQbrdbbrdbPp9P69evj7S3tbWprKxMQ4YMUXp6ukpLS9XU1BS1jYaGBs2aNUtpaWnKysrSkiVLLupGTAASj9PRpcL0LfKm/o9SHH+VQ2FJYSU72uVO+kI3Z/wfOcLt6uIICtBvxXQflOHDh+upp57S6NGjZYzRK6+8ojvuuEM7d+7U9ddfr8WLF+u3v/2t3njjDXk8HpWXl+vOO+/U73//e0lSKBTSrFmz5PV69dFHH6mxsVFz585VSkqKnnzyyUuygwDs8z9H/6/e/v1+Sfv1edtoBbuGysihQUktunLAAb3t6NT+hi8U5k6wQL/lMBd5L+jMzEw988wzuuuuuzRs2DCtXr1ad911lyRp//79uu6661RTU6MpU6Zo/fr1uu2223T06FFlZ2dLklauXKmlS5fq2LFjSk1NvaDPDAaD8ng8uv/++y/4PQB6hzFGv/zlLxVOgKcR33333fJ4PPEuA8D/6ujo0KpVqxQIBOR2u8/Zt8d3kg2FQnrjjTfU2toqn8+n2tpadXZ2qri4ONJn7Nixys/PjwSUmpoajR8/PhJOJKmkpEQLFy7U3r179Y1vfKPbz2pvb1d7e3vkdTAYlCTdd999Sk9P7+kuAOgBY4xWrVqVEAHl7//+75WXlxfvMgD8r1OnTmnVqlUX1DfmgLJ79275fD61tbUpPT1da9euVUFBgerq6pSamqqMjIyo/tnZ2fL7/ZIkv98fFU5Ot59uO5uqqiotX778jPWTJk06bwID0LtCoZAcjnPdw8Qe48eP17XXXhvvMgD8r9MHGC5EzFfxjBkzRnV1ddq+fbsWLlyoefPmad++fbFuJiaVlZUKBAKR5ciRI5f08wAAQHzFfAQlNTVVo0aNkiRNnDhRO3bs0PPPP6+7775bHR0damlpiTqK0tTUJK/XK0nyer36+OOPo7Z3+iqf032643K55HK5Yi0VAAAkqIu+D0o4HFZ7e7smTpyolJQUbdq0KdJWX1+vhoYG+Xw+SZLP59Pu3bvV3Nwc6bNx40a53W4VFBRcbCkAAOAyEdMRlMrKSs2cOVP5+fk6efKkVq9erS1btui9996Tx+PR/PnzVVFRoczMTLndbj300EPy+XyaMmWKJGn69OkqKCjQfffdp6efflp+v1/Lli1TWVkZR0gAAEBETAGlublZc+fOVWNjozwejwoLC/Xee+/pu9/9riTp2WefldPpVGlpqdrb21VSUqKXXnop8v6kpCStW7dOCxculM/n06BBgzRv3jw98cQTvbtXAAAgoV30fVDi4fR9UC7kOmoAvSsUCiktLU0dHR3xLuW86uvruYoHsEgs3988iwcAAFiHgAIAAKxDQAEAANYhoAAAAOv0+Fk8APonh8OhO+64Q52dnfEu5bx4VheQuAgoAGLidDr1+uuvx7sMAJc5TvEAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWiSmgrFixQoWFhXK73XK73fL5fFq/fn2kferUqXI4HFHLgw8+GLWNhoYGzZo1S2lpacrKytKSJUvU1dXVO3sDAAAuC8mxdB4+fLieeuopjR49WsYYvfLKK7rjjju0c+dOXX/99ZKkBx54QE888UTkPWlpaZGfQ6GQZs2aJa/Xq48++kiNjY2aO3euUlJS9OSTT/bSLgEAgETnMMaYi9lAZmamnnnmGc2fP19Tp07VDTfcoOeee67bvuvXr9dtt92mo0ePKjs7W5K0cuVKLV26VMeOHVNqauoFfWYwGJTH41EgEJDb7b6Y8gEAQB+J5fu7x3NQQqGQ1qxZo9bWVvl8vsj6V199VUOHDtW4ceNUWVmpL7/8MtJWU1Oj8ePHR8KJJJWUlCgYDGrv3r1n/az29nYFg8GoBQAAXL5iOsUjSbt375bP51NbW5vS09O1du1aFRQUSJK+//3va8SIEcrNzdWuXbu0dOlS1dfX680335Qk+f3+qHAiKfLa7/ef9TOrqqq0fPnyWEsFAAAJKuaAMmbMGNXV1SkQCOjXv/615s2bp+rqahUUFGjBggWRfuPHj1dOTo6mTZumQ4cO6ZprrulxkZWVlaqoqIi8DgaDysvL6/H2AACA3WI+xZOamqpRo0Zp4sSJqqqq0oQJE/T8889327eoqEiSdPDgQUmS1+tVU1NTVJ/Tr71e71k/0+VyRa4cOr0AAIDL10XfByUcDqu9vb3btrq6OklSTk6OJMnn82n37t1qbm6O9Nm4caPcbnfkNBEAAEBMp3gqKys1c+ZM5efn6+TJk1q9erW2bNmi9957T4cOHdLq1at16623asiQIdq1a5cWL16sW265RYWFhZKk6dOnq6CgQPfdd5+efvpp+f1+LVu2TGVlZXK5XJdkBwEAQOKJKaA0Nzdr7ty5amxslMfjUWFhod577z1997vf1ZEjR/T+++/rueeeU2trq/Ly8lRaWqply5ZF3p+UlKR169Zp4cKF8vl8GjRokObNmxd13xQAAICLvg9KPHAfFAAAEk+f3AcFAADgUiGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWSY53AT1hjJEkBYPBOFcCAAAu1Onv7dPf4+eSkAHl5MmTkqS8vLw4VwIAAGJ18uRJeTyec/ZxmAuJMZYJh8Oqr69XQUGBjhw5IrfbHe+SElYwGFReXh7j2AsYy97DWPYOxrH3MJa9wxijkydPKjc3V07nuWeZJOQRFKfTqSuvvFKS5Ha7+WXpBYxj72Esew9j2TsYx97DWF688x05OY1JsgAAwDoEFAAAYJ2EDSgul0uPPfaYXC5XvEtJaIxj72Esew9j2TsYx97DWPa9hJwkCwAALm8JewQFAABcvggoAADAOgQUAABgHQIKAACwTkIGlBdffFFXXXWVBgwYoKKiIn388cfxLsk6W7du1e23367c3Fw5HA699dZbUe3GGD366KPKycnRwIEDVVxcrAMHDkT1OXHihObMmSO3262MjAzNnz9fp06d6sO9iL+qqipNnjxZgwcPVlZWlmbPnq36+vqoPm1tbSorK9OQIUOUnp6u0tJSNTU1RfVpaGjQrFmzlJaWpqysLC1ZskRdXV19uStxtWLFChUWFkZucuXz+bR+/fpIO2PYc0899ZQcDocWLVoUWcd4XpjHH39cDocjahk7dmyknXGMM5Ng1qxZY1JTU81//ud/mr1795oHHnjAZGRkmKampniXZpV3333X/PM//7N58803jSSzdu3aqPannnrKeDwe89Zbb5k//vGP5nvf+54ZOXKk+etf/xrpM2PGDDNhwgSzbds289///d9m1KhR5t577+3jPYmvkpIS8/LLL5s9e/aYuro6c+utt5r8/Hxz6tSpSJ8HH3zQ5OXlmU2bNplPPvnETJkyxXzzm9+MtHd1dZlx48aZ4uJis3PnTvPuu++aoUOHmsrKynjsUlz85je/Mb/97W/Nn/70J1NfX29+8pOfmJSUFLNnzx5jDGPYUx9//LG56qqrTGFhofnhD38YWc94XpjHHnvMXH/99aaxsTGyHDt2LNLOOMZXwgWUG2+80ZSVlUVeh0Ihk5uba6qqquJYld2+HlDC4bDxer3mmWeeiaxraWkxLpfL/OpXvzLGGLNv3z4jyezYsSPSZ/369cbhcJjPP/+8z2q3TXNzs5FkqqurjTFfjVtKSop54403In0+/fRTI8nU1NQYY74Ki06n0/j9/kifFStWGLfbbdrb2/t2ByxyxRVXmP/4j/9gDHvo5MmTZvTo0Wbjxo3m29/+diSgMJ4X7rHHHjMTJkzoto1xjL+EOsXT0dGh2tpaFRcXR9Y5nU4VFxerpqYmjpUllsOHD8vv90eNo8fjUVFRUWQca2pqlJGRoUmTJkX6FBcXy+l0avv27X1esy0CgYAkKTMzU5JUW1urzs7OqLEcO3as8vPzo8Zy/Pjxys7OjvQpKSlRMBjU3r17+7B6O4RCIa1Zs0atra3y+XyMYQ+VlZVp1qxZUeMm8TsZqwMHDig3N1dXX3215syZo4aGBkmMow0S6mGBX3zxhUKhUNQvgyRlZ2dr//79caoq8fj9fknqdhxPt/n9fmVlZUW1JycnKzMzM9KnvwmHw1q0aJFuuukmjRs3TtJX45SamqqMjIyovl8fy+7G+nRbf7F79275fD61tbUpPT1da9euVUFBgerq6hjDGK1Zs0Z/+MMftGPHjjPa+J28cEVFRVq1apXGjBmjxsZGLV++XN/61re0Z88extECCRVQgHgqKyvTnj179OGHH8a7lIQ0ZswY1dXVKRAI6Ne//rXmzZun6urqeJeVcI4cOaIf/vCH2rhxowYMGBDvchLazJkzIz8XFhaqqKhII0aM0Ouvv66BAwfGsTJICXYVz9ChQ5WUlHTGLOqmpiZ5vd44VZV4To/VucbR6/Wqubk5qr2rq0snTpzol2NdXl6udevW6YMPPtDw4cMj671erzo6OtTS0hLV/+tj2d1Yn27rL1JTUzVq1ChNnDhRVVVVmjBhgp5//nnGMEa1tbVqbm7W3/zN3yg5OVnJycmqrq7WCy+8oOTkZGVnZzOePZSRkaFrr71WBw8e5PfSAgkVUFJTUzVx4kRt2rQpsi4cDmvTpk3y+XxxrCyxjBw5Ul6vN2ocg8Ggtm/fHhlHn8+nlpYW1dbWRvps3rxZ4XBYRUVFfV5zvBhjVF5errVr12rz5s0aOXJkVPvEiROVkpISNZb19fVqaGiIGsvdu3dHBb6NGzfK7XaroKCgb3bEQuFwWO3t7YxhjKZNm6bdu3errq4uskyaNElz5syJ/Mx49sypU6d06NAh5eTk8Htpg3jP0o3VmjVrjMvlMqtWrTL79u0zCxYsMBkZGVGzqPHVDP+dO3eanTt3Gknm3//9383OnTvNX/7yF2PMV5cZZ2RkmLffftvs2rXL3HHHHd1eZvyNb3zDbN++3Xz44Ydm9OjR/e4y44ULFxqPx2O2bNkSdSnil19+Genz4IMPmvz8fLN582bzySefGJ/PZ3w+X6T99KWI06dPN3V1dWbDhg1m2LBh/epSxEceecRUV1ebw4cPm127dplHHnnEOBwO87vf/c4YwxherP//Kh5jGM8L9fDDD5stW7aYw4cPm9///vemuLjYDB061DQ3NxtjGMd4S7iAYowxP//5z01+fr5JTU01N954o9m2bVu8S7LOBx98YCSdscybN88Y89Wlxj/96U9Ndna2cblcZtq0aaa+vj5qG8ePHzf33nuvSU9PN2632/zgBz8wJ0+ejMPexE93YyjJvPzyy5E+f/3rX80//dM/mSuuuMKkpaWZv/u7vzONjY1R2/nzn/9sZs6caQYOHGiGDh1qHn74YdPZ2dnHexM///iP/2hGjBhhUlNTzbBhw8y0adMi4cQYxvBifT2gMJ4X5u677zY5OTkmNTXVXHnllebuu+82Bw8ejLQzjvHlMMaY+By7AQAA6F5CzUEBAAD9AwEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANb5f2jaIUh2D2nDAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "#打印游戏\n",
    "def show():\n",
    "    plt.imshow(env.render())\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[0.4559, 0.5441],\n",
       "         [0.4245, 0.5755]], grad_fn=<SoftmaxBackward0>),\n",
       " tensor([[-0.3215],\n",
       "         [-0.3096]], grad_fn=<AddmmBackward0>))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "#定义模型\n",
    "model = torch.nn.Sequential(\n",
    "    torch.nn.Linear(4, 128),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(128, 2),\n",
    "    torch.nn.Softmax(dim=1),\n",
    ")\n",
    "\n",
    "model_td = torch.nn.Sequential(\n",
    "    torch.nn.Linear(4, 128),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(128, 1),\n",
    ")\n",
    "\n",
    "model(torch.randn(2, 4)), model_td(torch.randn(2, 4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "\n",
    "#得到一个动作\n",
    "def get_action(state):\n",
    "    state = torch.FloatTensor(state).reshape(1, 4)\n",
    "    #[1, 4] -> [1, 2]\n",
    "    prob = model(state)\n",
    "\n",
    "    #根据概率选择一个动作\n",
    "    action = random.choices(range(2), weights=prob[0].tolist(), k=1)[0]\n",
    "\n",
    "    return action\n",
    "\n",
    "\n",
    "get_action([1, 2, 3, 4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\cgq10\\AppData\\Local\\Temp\\ipykernel_19216\\2726165283.py:31: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\torch\\csrc\\utils\\tensor_new.cpp:248.)\n",
      "  states = torch.FloatTensor(states).reshape(-1, 4)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor([[ 6.3693e-03,  4.7836e-03, -1.5849e-02,  3.0064e-02],\n",
       "         [ 6.4649e-03, -1.9011e-01, -1.5248e-02,  3.1770e-01],\n",
       "         [ 2.6628e-03,  5.2282e-03, -8.8934e-03,  2.0252e-02],\n",
       "         [ 2.7673e-03,  2.0048e-01, -8.4884e-03, -2.7522e-01],\n",
       "         [ 6.7769e-03,  5.4768e-03, -1.3993e-02,  1.4771e-02],\n",
       "         [ 6.8864e-03, -1.8944e-01, -1.3697e-02,  3.0301e-01],\n",
       "         [ 3.0976e-03,  5.8727e-03, -7.6373e-03,  6.0348e-03],\n",
       "         [ 3.2150e-03, -1.8914e-01, -7.5166e-03,  2.9630e-01],\n",
       "         [-5.6775e-04, -3.8415e-01, -1.5906e-03,  5.8660e-01],\n",
       "         [-8.2508e-03, -1.8901e-01,  1.0141e-02,  2.9342e-01],\n",
       "         [-1.2031e-02,  5.9672e-03,  1.6010e-02,  3.9503e-03],\n",
       "         [-1.1912e-02,  2.0086e-01,  1.6089e-02, -2.8364e-01],\n",
       "         [-7.8945e-03,  3.9574e-01,  1.0416e-02, -5.7120e-01],\n",
       "         [ 2.0379e-05,  5.9072e-01, -1.0081e-03, -8.6059e-01],\n",
       "         [ 1.1835e-02,  7.8585e-01, -1.8220e-02, -1.1536e+00],\n",
       "         [ 2.7552e-02,  5.9098e-01, -4.1292e-02, -8.6667e-01],\n",
       "         [ 3.9371e-02,  7.8663e-01, -5.8625e-02, -1.1720e+00],\n",
       "         [ 5.5104e-02,  5.9232e-01, -8.2066e-02, -8.9830e-01],\n",
       "         [ 6.6950e-02,  3.9840e-01, -1.0003e-01, -6.3250e-01],\n",
       "         [ 7.4919e-02,  2.0481e-01, -1.1268e-01, -3.7292e-01],\n",
       "         [ 7.9015e-02,  1.1451e-02, -1.2014e-01, -1.1779e-01],\n",
       "         [ 7.9244e-02,  2.0807e-01, -1.2250e-01, -4.4583e-01],\n",
       "         [ 8.3405e-02,  4.0469e-01, -1.3141e-01, -7.7448e-01],\n",
       "         [ 9.1499e-02,  6.0136e-01, -1.4690e-01, -1.1055e+00],\n",
       "         [ 1.0353e-01,  7.9807e-01, -1.6901e-01, -1.4404e+00],\n",
       "         [ 1.1949e-01,  9.9482e-01, -1.9782e-01, -1.7808e+00]]),\n",
       " tensor([[1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.]]),\n",
       " tensor([[0],\n",
       "         [1],\n",
       "         [1],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [0],\n",
       "         [1],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [0]]),\n",
       " tensor([[ 6.4649e-03, -1.9011e-01, -1.5248e-02,  3.1770e-01],\n",
       "         [ 2.6628e-03,  5.2282e-03, -8.8934e-03,  2.0252e-02],\n",
       "         [ 2.7673e-03,  2.0048e-01, -8.4884e-03, -2.7522e-01],\n",
       "         [ 6.7769e-03,  5.4768e-03, -1.3993e-02,  1.4771e-02],\n",
       "         [ 6.8864e-03, -1.8944e-01, -1.3697e-02,  3.0301e-01],\n",
       "         [ 3.0976e-03,  5.8727e-03, -7.6373e-03,  6.0348e-03],\n",
       "         [ 3.2150e-03, -1.8914e-01, -7.5166e-03,  2.9630e-01],\n",
       "         [-5.6775e-04, -3.8415e-01, -1.5906e-03,  5.8660e-01],\n",
       "         [-8.2508e-03, -1.8901e-01,  1.0141e-02,  2.9342e-01],\n",
       "         [-1.2031e-02,  5.9672e-03,  1.6010e-02,  3.9503e-03],\n",
       "         [-1.1912e-02,  2.0086e-01,  1.6089e-02, -2.8364e-01],\n",
       "         [-7.8945e-03,  3.9574e-01,  1.0416e-02, -5.7120e-01],\n",
       "         [ 2.0379e-05,  5.9072e-01, -1.0081e-03, -8.6059e-01],\n",
       "         [ 1.1835e-02,  7.8585e-01, -1.8220e-02, -1.1536e+00],\n",
       "         [ 2.7552e-02,  5.9098e-01, -4.1292e-02, -8.6667e-01],\n",
       "         [ 3.9371e-02,  7.8663e-01, -5.8625e-02, -1.1720e+00],\n",
       "         [ 5.5104e-02,  5.9232e-01, -8.2066e-02, -8.9830e-01],\n",
       "         [ 6.6950e-02,  3.9840e-01, -1.0003e-01, -6.3250e-01],\n",
       "         [ 7.4919e-02,  2.0481e-01, -1.1268e-01, -3.7292e-01],\n",
       "         [ 7.9015e-02,  1.1451e-02, -1.2014e-01, -1.1779e-01],\n",
       "         [ 7.9244e-02,  2.0807e-01, -1.2250e-01, -4.4583e-01],\n",
       "         [ 8.3405e-02,  4.0469e-01, -1.3141e-01, -7.7448e-01],\n",
       "         [ 9.1499e-02,  6.0136e-01, -1.4690e-01, -1.1055e+00],\n",
       "         [ 1.0353e-01,  7.9807e-01, -1.6901e-01, -1.4404e+00],\n",
       "         [ 1.1949e-01,  9.9482e-01, -1.9782e-01, -1.7808e+00],\n",
       "         [ 1.3938e-01,  8.0240e-01, -2.3343e-01, -1.5555e+00]]),\n",
       " tensor([[0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1]]))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_data():\n",
    "    states = []\n",
    "    rewards = []\n",
    "    actions = []\n",
    "    next_states = []\n",
    "    overs = []\n",
    "\n",
    "    #初始化游戏\n",
    "    state = env.reset()\n",
    "\n",
    "    #玩到游戏结束为止\n",
    "    over = False\n",
    "    while not over:\n",
    "        #根据当前状态得到一个动作\n",
    "        action = get_action(state)\n",
    "\n",
    "        #执行动作,得到反馈\n",
    "        next_state, reward, over, _ = env.step(action)\n",
    "\n",
    "        #记录数据样本\n",
    "        states.append(state)\n",
    "        rewards.append(reward)\n",
    "        actions.append(action)\n",
    "        next_states.append(next_state)\n",
    "        overs.append(over)\n",
    "\n",
    "        #更新游戏状态,开始下一个动作\n",
    "        state = next_state\n",
    "\n",
    "    #[b, 4]\n",
    "    states = torch.FloatTensor(states).reshape(-1, 4)\n",
    "    #[b, 1]\n",
    "    rewards = torch.FloatTensor(rewards).reshape(-1, 1)\n",
    "    #[b, 1]\n",
    "    actions = torch.LongTensor(actions).reshape(-1, 1)\n",
    "    #[b, 4]\n",
    "    next_states = torch.FloatTensor(next_states).reshape(-1, 4)\n",
    "    #[b, 1]\n",
    "    overs = torch.LongTensor(overs).reshape(-1, 1)\n",
    "\n",
    "    return states, rewards, actions, next_states, overs\n",
    "\n",
    "\n",
    "get_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "29.0"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython import display\n",
    "\n",
    "\n",
    "def test(play):\n",
    "    #初始化游戏\n",
    "    state = env.reset()\n",
    "\n",
    "    #记录反馈值的和,这个值越大越好\n",
    "    reward_sum = 0\n",
    "\n",
    "    #玩到游戏结束为止\n",
    "    over = False\n",
    "    while not over:\n",
    "        #根据当前状态得到一个动作\n",
    "        action = get_action(state)\n",
    "\n",
    "        #执行动作,得到反馈\n",
    "        state, reward, over, _ = env.step(action)\n",
    "        reward_sum += reward\n",
    "\n",
    "        #打印动画\n",
    "        if play and random.random() < 0.2:  #跳帧\n",
    "            display.clear_output(wait=True)\n",
    "            show()\n",
    "\n",
    "    return reward_sum\n",
    "\n",
    "\n",
    "test(play=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[8.090483997483998, 8.690100963999999, 8.260044, 6.724, 4.0]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#优势函数\n",
    "def get_advantages(deltas):\n",
    "    advantages = []\n",
    "\n",
    "    #反向遍历deltas\n",
    "    s = 0.0\n",
    "    for delta in deltas[::-1]:\n",
    "        s = 0.98 * 0.95 * s + delta\n",
    "        advantages.append(s)\n",
    "\n",
    "    #逆序\n",
    "    advantages.reverse()\n",
    "    return advantages\n",
    "\n",
    "\n",
    "get_advantages(range(5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 20.0\n",
      "50 181.8\n",
      "100 193.6\n",
      "150 200.0\n",
      "200 200.0\n",
      "250 200.0\n",
      "300 200.0\n",
      "350 200.0\n",
      "400 190.5\n",
      "450 191.6\n"
     ]
    }
   ],
   "source": [
    "def train():\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
    "    optimizer_td = torch.optim.Adam(model_td.parameters(), lr=1e-2)\n",
    "    loss_fn = torch.nn.MSELoss()\n",
    "\n",
    "    #玩N局游戏,每局游戏训练M次\n",
    "    for epoch in range(500):\n",
    "        #玩一局游戏,得到数据\n",
    "        #states -> [b, 4]\n",
    "        #rewards -> [b, 1]\n",
    "        #actions -> [b, 1]\n",
    "        #next_states -> [b, 4]\n",
    "        #overs -> [b, 1]\n",
    "        states, rewards, actions, next_states, overs = get_data()\n",
    "\n",
    "        #计算values和targets\n",
    "        #[b, 4] -> [b, 1]\n",
    "        values = model_td(states)\n",
    "\n",
    "        #[b, 4] -> [b, 1]\n",
    "        targets = model_td(next_states).detach()\n",
    "        targets = targets * 0.98\n",
    "        targets *= (1 - overs)\n",
    "        targets += rewards\n",
    "\n",
    "        #计算优势,这里的advantages有点像是策略梯度里的reward_sum\n",
    "        #只是这里计算的不是reward,而是target和value的差\n",
    "        #[b, 1]\n",
    "        deltas = (targets - values).squeeze(dim=1).tolist()\n",
    "        advantages = get_advantages(deltas)\n",
    "        advantages = torch.FloatTensor(advantages).reshape(-1, 1)\n",
    "\n",
    "        #取出每一步动作的概率\n",
    "        #[b, 2] -> [b, 2] -> [b, 1]\n",
    "        old_probs = model(states)\n",
    "        old_probs = old_probs.gather(dim=1, index=actions)\n",
    "        old_probs = old_probs.detach()\n",
    "\n",
    "        #每批数据反复训练10次\n",
    "        for _ in range(10):\n",
    "            #重新计算每一步动作的概率\n",
    "            #[b, 4] -> [b, 2]\n",
    "            new_probs = model(states)\n",
    "            #[b, 2] -> [b, 1]\n",
    "            new_probs = new_probs.gather(dim=1, index=actions)\n",
    "            new_probs = new_probs\n",
    "\n",
    "            #求出概率的变化\n",
    "            #[b, 1] - [b, 1] -> [b, 1]\n",
    "            ratios = new_probs / old_probs\n",
    "\n",
    "            #计算截断的和不截断的两份loss,取其中小的\n",
    "            #[b, 1] * [b, 1] -> [b, 1]\n",
    "            surr1 = ratios * advantages\n",
    "            #[b, 1] * [b, 1] -> [b, 1]\n",
    "            surr2 = torch.clamp(ratios, 0.8, 1.2) * advantages\n",
    "\n",
    "            loss = -torch.min(surr1, surr2)\n",
    "            loss = loss.mean()\n",
    "\n",
    "            #重新计算value,并计算时序差分loss\n",
    "            values = model_td(states)\n",
    "            loss_td = loss_fn(values, targets)\n",
    "\n",
    "            #更新参数\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            optimizer_td.zero_grad()\n",
    "            loss_td.backward()\n",
    "            optimizer_td.step()\n",
    "\n",
    "        if epoch % 50 == 0:\n",
    "            test_result = sum([test(play=False) for _ in range(10)]) / 10\n",
    "            print(epoch, test_result)\n",
    "\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAF7CAYAAAD4/3BBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAqNklEQVR4nO3df3RU9Z3/8ddMJjMkhJkQIJmkJIhChQjBFjBMrS5dUgKilTWeo5YV7HLkK038FmMtpmtVbI9xdc/6o6vwx+6Ke46U1n5FVypYBAlrjYApWQJqKizdYMkkSDaZJJBfM5/vHy6zHYlJJr/mJnk+zrmHzL3vufO+n5MzeXF/2owxRgAAABZij3UDAAAAX0RAAQAAlkNAAQAAlkNAAQAAlkNAAQAAlkNAAQAAlkNAAQAAlkNAAQAAlkNAAQAAlkNAAQAAlhPTgPL888/rsssu07hx45Sbm6tDhw7Fsh0AAGARMQsov/zlL1VcXKxHHnlEv//97zVv3jzl5+ervr4+Vi0BAACLsMXqYYG5ublauHCh/vEf/1GSFAqFlJmZqXvvvVcPPvhgLFoCAAAW4YjFh3Z0dKiiokIlJSXheXa7XXl5eSovL7+kvr29Xe3t7eHXoVBIDQ0NmjRpkmw227D0DAAABsYYo+bmZmVkZMhu7/kgTkwCymeffaZgMKi0tLSI+Wlpafr4448vqS8tLdWmTZuGqz0AADCETp8+ralTp/ZYE5OAEq2SkhIVFxeHXzc1NSkrK0unT5+W2+2OYWcAAKCvAoGAMjMzNWHChF5rYxJQJk+erLi4ONXV1UXMr6urk9frvaTe5XLJ5XJdMt/tdhNQAAAYYfpyekZMruJxOp2aP3++9u7dG54XCoW0d+9e+Xy+WLQEAAAsJGaHeIqLi7VmzRotWLBA11xzjZ555hm1trbqe9/7XqxaAgAAFhGzgHLbbbfp7Nmzevjhh+X3+3X11Vdr9+7dl5w4CwAAxp6Y3QdlIAKBgDwej5qamjgHBQCAESKav988iwcAAFgOAQUAAFgOAQUAAFgOAQUAAFgOAQUAAFgOAQUAAFgOAQUAAFgOAQUAAFgOAQUAAFgOAQUAAFgOAQUAAFgOAQUAAFgOAQUAAFgOAQUAAFgOAQUAAFgOAQUAAFgOAQUAAFgOAQUAAFgOAQUAAFgOAQUAAFgOAQUAAFgOAQUAAFgOAQUAAFgOAQUAAFgOAQUAAFgOAQUAAFgOAQUAAFgOAQUAAFgOAQUAAFgOAQUAAFjOoAeURx99VDabLWKaNWtWeHlbW5sKCws1adIkJSUlqaCgQHV1dYPdBgAAGMGGZA/KVVddpdra2vD07rvvhpfdd999euONN/TKK6+orKxMZ86c0S233DIUbQAAgBHKMSQrdTjk9Xovmd/U1KR//ud/1rZt2/SXf/mXkqQXX3xRs2fP1vvvv69FixYNRTsAAGCEGZI9KJ988okyMjJ0+eWXa9WqVaqpqZEkVVRUqLOzU3l5eeHaWbNmKSsrS+Xl5V+6vvb2dgUCgYgJAACMXoMeUHJzc7V161bt3r1bmzdv1qlTp3TdddepublZfr9fTqdTycnJEe9JS0uT3+//0nWWlpbK4/GEp8zMzMFuGwAAWMigH+JZvnx5+OecnBzl5uZq2rRp+tWvfqWEhIR+rbOkpETFxcXh14FAgJACAMAoNuSXGScnJ+urX/2qTpw4Ia/Xq46ODjU2NkbU1NXVdXvOykUul0tutztiAgAAo9eQB5SWlhadPHlS6enpmj9/vuLj47V3797w8urqatXU1Mjn8w11KwAAYIQY9EM8P/zhD3XTTTdp2rRpOnPmjB555BHFxcXpjjvukMfj0dq1a1VcXKyUlBS53W7de++98vl8XMEDAADCBj2gfPrpp7rjjjt07tw5TZkyRd/85jf1/vvva8qUKZKkp59+Wna7XQUFBWpvb1d+fr5eeOGFwW4DAACMYDZjjIl1E9EKBALyeDxqamrifBQAAEaIaP5+8yweAABgOQQUAABgOQQUAABgOQQUAABgOQQUAABgOQQUAABgOQQUAABgOQQUAABgOQQUAABgOQQUAABgOQQUAABgOQQUAABgOQQUAABgOQQUAABgOQQUAABgOQQUAABgOQQUAABgOQQUAABgOQQUAABgOQQUAABgOQQUAABgOQQUAABgOQQUAABgOQQUAABgOQQUAABgOQQUAABgOQQUAABgOQQUAABgOQQUAABgOQQUAABgOQQUAABgOVEHlAMHDuimm25SRkaGbDabXnvttYjlxhg9/PDDSk9PV0JCgvLy8vTJJ59E1DQ0NGjVqlVyu91KTk7W2rVr1dLSMqANAQAAo0fUAaW1tVXz5s3T888/3+3yJ598Us8995y2bNmigwcPavz48crPz1dbW1u4ZtWqVTp+/Lj27NmjnTt36sCBA1q3bl3/twIAAIwqNmOM6febbTbt2LFDK1eulPT53pOMjAzdf//9+uEPfyhJampqUlpamrZu3arbb79dH330kbKzs3X48GEtWLBAkrR7927dcMMN+vTTT5WRkdHr5wYCAXk8HjU1Ncntdve3fQAAMIyi+fs9qOegnDp1Sn6/X3l5eeF5Ho9Hubm5Ki8vlySVl5crOTk5HE4kKS8vT3a7XQcPHux2ve3t7QoEAhETAAAYvQY1oPj9fklSWlpaxPy0tLTwMr/fr9TU1IjlDodDKSkp4ZovKi0tlcfjCU+ZmZmD2TYAALCYEXEVT0lJiZqamsLT6dOnY90SAAAYQoMaULxerySprq4uYn5dXV14mdfrVX19fcTyrq4uNTQ0hGu+yOVyye12R0wAAGD0GtSAMn36dHm9Xu3duzc8LxAI6ODBg/L5fJIkn8+nxsZGVVRUhGv27dunUCik3NzcwWwHAACMUI5o39DS0qITJ06EX586dUqVlZVKSUlRVlaWNmzYoJ/97GeaOXOmpk+frp/85CfKyMgIX+kze/ZsLVu2THfffbe2bNmizs5OFRUV6fbbb+/TFTwAAGD0izqgfPDBB/rWt74Vfl1cXCxJWrNmjbZu3aof/ehHam1t1bp169TY2KhvfvOb2r17t8aNGxd+z8svv6yioiItWbJEdrtdBQUFeu655wZhcwAAwGgwoPugxAr3QQEAYOSJ2X1QAAAABgMBBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWE7UAeXAgQO66aablJGRIZvNptdeey1i+V133SWbzRYxLVu2LKKmoaFBq1atktvtVnJystauXauWlpYBbQgAABg9og4ora2tmjdvnp5//vkvrVm2bJlqa2vD0y9+8YuI5atWrdLx48e1Z88e7dy5UwcOHNC6deui7x4AAIxKjmjfsHz5ci1fvrzHGpfLJa/X2+2yjz76SLt379bhw4e1YMECSdLPf/5z3XDDDfr7v/97ZWRkRNsSAAAYZYbkHJT9+/crNTVVV155pdavX69z586Fl5WXlys5OTkcTiQpLy9PdrtdBw8e7HZ97e3tCgQCERMAABi9Bj2gLFu2TP/6r/+qvXv36u/+7u9UVlam5cuXKxgMSpL8fr9SU1Mj3uNwOJSSkiK/39/tOktLS+XxeMJTZmbmYLcNAAAsJOpDPL25/fbbwz/PnTtXOTk5uuKKK7R//34tWbKkX+ssKSlRcXFx+HUgECCkAAAwig35ZcaXX365Jk+erBMnTkiSvF6v6uvrI2q6urrU0NDwpeetuFwuud3uiAkAAIxeQx5QPv30U507d07p6emSJJ/Pp8bGRlVUVIRr9u3bp1AopNzc3KFuBwAAjABRH+JpaWkJ7w2RpFOnTqmyslIpKSlKSUnRpk2bVFBQIK/Xq5MnT+pHP/qRZsyYofz8fEnS7NmztWzZMt19993asmWLOjs7VVRUpNtvv50reAAAgCTJZowx0bxh//79+ta3vnXJ/DVr1mjz5s1auXKljhw5osbGRmVkZGjp0qX66U9/qrS0tHBtQ0ODioqK9MYbb8hut6ugoEDPPfeckpKS+tRDIBCQx+NRU1MTh3sAABghovn7HXVAsQICCgAAI080f795Fg8AALAcAgoAALAcAgoAALAcAgoAALAcAgoAALAcAgoAALAcAgoAALAcAgoAALAcAgoAALAcAgoAALCcqB8WCGDsOd/wJ5354A05EiYoPmGCHOPcik+cIMe4CYpPSFJ8gltxrvGy2WyxbhXAKEFAAdAjY4w6mj/Tf5/6vWSPk81m/3yyf/6v/uffcZ40zfrOD2PdLoBRgoACoGcmpM4LLZ//HArKKKjunjBqjx83rG0BGN04BwVAj4wxCrafj3UbAMYYAgqAXhh1tbfGugkAYwwBBUDPjFFXW0usuwAwxhBQAPQo2NWhz/5Q3mudzR43DN0AGCsIKAB6ZoxMsLPnGptN6VcvG55+AIwJBBQAg8AmR0JSrJsAMIoQUAAMivhxE2LdAoBRhIACYMBsEntQAAwqAgqAL2WMUairl/NPJMlmk2McAQXA4CGgAOhR54VAn+rs9vgh7gTAWEJAAdAj7oECIBYIKAB61NXWHOsWAIxBBBQAPeo8T0ABMPwIKAB6VHvkzV5r7A7nMHQCYCwhoADokQkFe61Jm5sn2WzD0A2AsYKAAmDAuMQYwGAjoAAYsPgE7iILYHBFFVBKS0u1cOFCTZgwQampqVq5cqWqq6sjatra2lRYWKhJkyYpKSlJBQUFqquri6ipqanRihUrlJiYqNTUVD3wwAPq6uoa+NYAiAkHt7kHMMiiCihlZWUqLCzU+++/rz179qizs1NLly5Va2truOa+++7TG2+8oVdeeUVlZWU6c+aMbrnllvDyYDCoFStWqKOjQ++9955eeuklbd26VQ8//PDgbRWAQWL6VMVt7gEMNpsxpm/fQN04e/asUlNTVVZWpuuvv15NTU2aMmWKtm3bpltvvVWS9PHHH2v27NkqLy/XokWLtGvXLt144406c+aM0tLSJElbtmzRxo0bdfbsWTmdvV8NEAgE5PF41NTUJLfb3d/2AfSi80JAR7f9WKGujh7rcr5bKteEScPUFYCRKpq/3wM6B6WpqUmSlJKSIkmqqKhQZ2en8vLywjWzZs1SVlaWysvLJUnl5eWaO3duOJxIUn5+vgKBgI4fP97t57S3tysQCERMAIZeV1uLBvB/GADot34HlFAopA0bNujaa6/VnDlzJEl+v19Op1PJyckRtWlpafL7/eGaPw8nF5dfXNad0tJSeTye8JSZmdnftgFEoauttfciABgC/Q4ohYWFOnbsmLZv3z6Y/XSrpKRETU1N4en06dND/pkA/uc5POxBARADjv68qaioSDt37tSBAwc0derU8Hyv16uOjg41NjZG7EWpq6uT1+sN1xw6dChifRev8rlY80Uul0sul6s/rQIYgLbAWfV2oqw9fpxsNu5YAGBwRfWtYoxRUVGRduzYoX379mn69OkRy+fPn6/4+Hjt3bs3PK+6ulo1NTXy+XySJJ/Pp6qqKtXX14dr9uzZI7fbrezs7IFsC4BB1nDiUK93kp04/WuyO8cNU0cAxoqo9qAUFhZq27Ztev311zVhwoTwOSMej0cJCQnyeDxau3atiouLlZKSIrfbrXvvvVc+n0+LFi2SJC1dulTZ2dm688479eSTT8rv9+uhhx5SYWEhe0mAEcjhSmQPCoBBF1VA2bx5syRp8eLFEfNffPFF3XXXXZKkp59+Wna7XQUFBWpvb1d+fr5eeOGFcG1cXJx27typ9evXy+fzafz48VqzZo0ee+yxgW0JgJiIcybKxnN4AAyyAd0HJVa4DwowPI7/v5/p/Gc1PdZkXXu7psy+Xva4fp3SBmAMGbb7oAAYvfr6fxeHazyHeAAMOr5VAHQr1NUhEwr1WmePHydxiAfAICOgAOhWsONCr1fwSJLNZuMcFACDjoACoFt9DSgAMBQIKAC6RUABEEsEFADdam8+1+tTjAFgqBBQAHQr8KeP1NXW3GONM2mi4lyJw9QRgLGEgAKg3xJSMuUcPzHWbQAYhQgoAPrNHu+SPS4+1m0AGIUIKAD6LS7eJRt3kAUwBAgoAPotLn4ce1AADAkCCoBLmFBQMn24i6zDKVtc3DB0BGCsIaAAuESoq0PBzt4vMbbZ7TyHB8CQ4JsFwCVCXR0KBbkHCoDYIaAAuESoq4ObtAGIKQIKgEuEujoJKABiioAC4BLtzZ+po/lcjzUO13i5Jkwepo4AjDUEFACX6DjfqM7zTT3WOBI9GpfsHaaOAIw1BBQA/WKPcyguflys2wAwShFQAPSLLc4hu5OAAmBoEFAA9Is9Ll5xDles2wAwShFQAEQwxkihPtxFNs6hOPagABgiBBQAEUwoqGBXe++FNrtsdm5zD2BoEFAARDChoILtF2LdBoAxjoACIIIJBRXsIKAAiC0CCoAIBBQAVkBAARChq61FLXUnY90GgDGOgAIgQqizXe2Bsz3W2ONdmnjZ1cPTEIAxiYACIGo2u0MuT2qs2wAwihFQAETNZrPJ4Rof6zYAjGIEFABRs9nsBBQAQyqqgFJaWqqFCxdqwoQJSk1N1cqVK1VdXR1Rs3jxYtlstojpnnvuiaipqanRihUrlJiYqNTUVD3wwAPq6uoa+NYAGB42u+JcibHuAsAo5oimuKysTIWFhVq4cKG6urr04x//WEuXLtWHH36o8eP/939Td999tx577LHw68TE//0iCwaDWrFihbxer9577z3V1tZq9erVio+P1+OPPz4ImwSgv4wxCna29V5os8nucA59QwDGrKgCyu7duyNeb926VampqaqoqND1118fnp+YmCiv19vtOn7729/qww8/1Ntvv620tDRdffXV+ulPf6qNGzfq0UcfldPJlx4QO0adbS19qrTZbEPcC4CxbEDnoDQ1NUmSUlJSIua//PLLmjx5subMmaOSkhKdP38+vKy8vFxz585VWlpaeF5+fr4CgYCOHz/e7ee0t7crEAhETACGgDHq6mNAAYChFNUelD8XCoW0YcMGXXvttZozZ054/ne/+11NmzZNGRkZOnr0qDZu3Kjq6mq9+uqrkiS/3x8RTiSFX/v9/m4/q7S0VJs2bepvqwCiECSgALCAfgeUwsJCHTt2TO+++27E/HXr1oV/njt3rtLT07VkyRKdPHlSV1xxRb8+q6SkRMXFxeHXgUBAmZmZ/WscwJcypu+HeABgKPXrEE9RUZF27typd955R1OnTu2xNjc3V5J04sQJSZLX61VdXV1EzcXXX3beisvlktvtjpgADD4T6tK5P5T3Wjdl9nXD0A2AsSyqgGKMUVFRkXbs2KF9+/Zp+vTpvb6nsrJSkpSeni5J8vl8qqqqUn19fbhmz549crvdys7OjqYdAIPNqE8PChzn5i6yAIZWVId4CgsLtW3bNr3++uuaMGFC+JwRj8ejhIQEnTx5Utu2bdMNN9ygSZMm6ejRo7rvvvt0/fXXKycnR5K0dOlSZWdn684779STTz4pv9+vhx56SIWFhXK5XIO/hQAGnSNhQqxbADDKRbUHZfPmzWpqatLixYuVnp4enn75y19KkpxOp95++20tXbpUs2bN0v3336+CggK98cYb4XXExcVp586diouLk8/n01//9V9r9erVEfdNAWBlNsUTUAAMsaj2oBhjelyemZmpsrKyXtczbdo0vfnmm9F8NAALcYxLinULAEY5nsUDICwU7OxTHc/hATDUCCgAwjovNPetkLvIAhhiBBQAYV19DSgAMMQIKADCutoIKACsgYACIIzn8ACwCgIKgLDaI7t6rZk8+5s8yRjAkCOgAAgLdnX0WuMcP1ESAQXA0CKgAIiKY1wS+QTAkCOgAIjK5zdpI6EAGFoEFABR4Tb3AIYDAQWApN4fZXGRw8Vt7gEMPQIKAElSsL1V6kNIsdntXMUDYMgRUABIkjrbWiX1bS8KAAw1AgoASZ/vQenrYR4AGGoEFACSpK62vh3iAYDhQEABIEnqOP/fMibUY40tLl42G18bAIYe3zQAJEnn/nBQJtjZY03KFQvlSOAqHgBDj4ACoM/iXImy2eJi3QaAMYCAAqDPHM4E2ex8bQAYeo5YNwBg4IwxCgaDA1tHHy4xtsW7FAyFFOrq6vfnOBx87QDoHd8UwChw8OBBXXfddQNaxwsbluvqGd4ea/7P9/+v3jp8st8X+0ycOFH19fX9ezOAMYWAAowCxhh1DWCvhtMRp77cHLa9o1Odnf3/nIH0CGBsIaAA0Phx8YqP+/zk186QU3Ud03QhNEF2BTXB0aBU5+kYdwhgrCGgAND4hHjFO+wKmjhVBJaqJThRncYlm4xc9vNqcP2nZo0/GOs2AYwhBBQAGj/Oqbi4eP2u8Ra1BpMlfX68x0hqC03QHy/MUZy6ZMy+WLYJYAzhekEAmuxJVFXbzRHh5M8ZxenEha/L33H5sPcGYGwioADQNbO/orSJSeounPyvPpxFCwCDhIACAAAsh4ACAAAsh4ACQJI0371b4+wtUrd3lA2pumqH/njy3eFuC8AYFVVA2bx5s3JycuR2u+V2u+Xz+bRr167w8ra2NhUWFmrSpElKSkpSQUGB6urqItZRU1OjFStWKDExUampqXrggQe4eRNgAfH2Tl038VdKimuQw9YhycimkJy288p0VStD+3WhvS3WbQIYI6K6zHjq1Kl64oknNHPmTBlj9NJLL+nmm2/WkSNHdNVVV+m+++7Tb37zG73yyivyeDwqKirSLbfcot/97neSpGAwqBUrVsjr9eq9995TbW2tVq9erfj4eD3++ONDsoEAelf1n/VKcH0sSeoMfaI/tc9Ua9Ajuy2oZEe9ml3/qSOf1CoY7Oc97gEgSjZj+vtUjc+lpKToqaee0q233qopU6Zo27ZtuvXWWyVJH3/8sWbPnq3y8nItWrRIu3bt0o033qgzZ84oLS1NkrRlyxZt3LhRZ8+eldPp7NNnBgIBeTwe3XXXXX1+DzCa1dXV6fXXX491G71yOp266667Yt0GgBjp6OjQ1q1b1dTUJLfb3WNtv2/UFgwG9corr6i1tVU+n08VFRXq7OxUXl5euGbWrFnKysoKB5Ty8nLNnTs3HE4kKT8/X+vXr9fx48f1ta99rdvPam9vV3t7e/h1IBCQJN15551KSkrq7yYAo0ZVVdWICCgul0tr166NdRsAYqSlpUVbt27tU23UAaWqqko+n09tbW1KSkrSjh07lJ2drcrKSjmdTiUnJ0fUp6Wlye/3S5L8fn9EOLm4/OKyL1NaWqpNmzZdMn/BggW9JjBgLAgGg7FuoU8cDoeuueaaWLcBIEYu7mDoi6iv4rnyyitVWVmpgwcPav369VqzZo0+/PDDaFcTlZKSEjU1NYWn06d5cBkAAKNZ1HtQnE6nZsyYIUmaP3++Dh8+rGeffVa33XabOjo61NjYGLEXpa6uTl6vV5Lk9Xp16NChiPVdvMrnYk13XC6XXC5XtK0CAIARasD3QQmFQmpvb9f8+fMVHx+vvXv3hpdVV1erpqZGPp9PkuTz+VRVVaX6+vpwzZ49e+R2u5WdnT3QVgAAwCgR1R6UkpISLV++XFlZWWpubta2bdu0f/9+vfXWW/J4PFq7dq2Ki4uVkpIit9ute++9Vz6fT4sWLZIkLV26VNnZ2brzzjv15JNPyu/366GHHlJhYSF7SAAAQFhUAaW+vl6rV69WbW2tPB6PcnJy9NZbb+nb3/62JOnpp5+W3W5XQUGB2tvblZ+frxdeeCH8/ri4OO3cuVPr16+Xz+fT+PHjtWbNGj322GODu1UAAGBEG/B9UGLh4n1Q+nIdNTAWlJeX6xvf+Eas2+jVxIkT1dDQEOs2AMRINH+/eRYPAACwHAIKAACwHAIKAACwHAIKAACwnH4/iweAdaSkpGjlypWxbqNXPDsLQF9xFQ8AABgWXMUDAABGNAIKAACwHAIKAACwHAIKAACwHAIKAACwHAIKAACwHAIKAACwHAIKAACwHAIKAACwHAIKAACwHAIKAACwHAIKAACwHAIKAACwHAIKAACwHAIKAACwHAIKAACwHAIKAACwHAIKAACwHAIKAACwHAIKAACwHAIKAACwHAIKAACwHAIKAACwHAIKAACwnKgCyubNm5WTkyO32y232y2fz6ddu3aFly9evFg2my1iuueeeyLWUVNToxUrVigxMVGpqal64IEH1NXVNThbAwAARgVHNMVTp07VE088oZkzZ8oYo5deekk333yzjhw5oquuukqSdPfdd+uxxx4LvycxMTH8czAY1IoVK+T1evXee++ptrZWq1evVnx8vB5//PFB2iQAADDS2YwxZiArSElJ0VNPPaW1a9dq8eLFuvrqq/XMM890W7tr1y7deOONOnPmjNLS0iRJW7Zs0caNG3X27Fk5nc4+fWYgEJDH41FTU5PcbvdA2gcAAMMkmr/f/T4HJRgMavv27WptbZXP5wvPf/nllzV58mTNmTNHJSUlOn/+fHhZeXm55s6dGw4nkpSfn69AIKDjx49/6We1t7crEAhETAAAYPSK6hCPJFVVVcnn86mtrU1JSUnasWOHsrOzJUnf/e53NW3aNGVkZOjo0aPauHGjqqur9eqrr0qS/H5/RDiRFH7t9/u/9DNLS0u1adOmaFsFAAAjVNQB5corr1RlZaWampr061//WmvWrFFZWZmys7O1bt26cN3cuXOVnp6uJUuW6OTJk7riiiv63WRJSYmKi4vDrwOBgDIzM/u9PgAAYG1RH+JxOp2aMWOG5s+fr9LSUs2bN0/PPvtst7W5ubmSpBMnTkiSvF6v6urqImouvvZ6vV/6mS6XK3zl0MUJAACMXgO+D0ooFFJ7e3u3yyorKyVJ6enpkiSfz6eqqirV19eHa/bs2SO32x0+TAQAABDVIZ6SkhItX75cWVlZam5u1rZt27R//3699dZbOnnypLZt26YbbrhBkyZN0tGjR3Xffffp+uuvV05OjiRp6dKlys7O1p133qknn3xSfr9fDz30kAoLC+VyuYZkAwEAwMgTVUCpr6/X6tWrVVtbK4/Ho5ycHL311lv69re/rdOnT+vtt9/WM888o9bWVmVmZqqgoEAPPfRQ+P1xcXHauXOn1q9fL5/Pp/Hjx2vNmjUR900BAAAY8H1QYoH7oAAAMPIMy31QAAAAhgoBBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWA4BBQAAWI4j1g30hzFGkhQIBGLcCQAA6KuLf7cv/h3vyYgMKM3NzZKkzMzMGHcCAACi1dzcLI/H02ONzfQlxlhMKBRSdXW1srOzdfr0abnd7li3NGIFAgFlZmYyjoOAsRw8jOXgYBwHD2M5OIwxam5uVkZGhuz2ns8yGZF7UOx2u77yla9IktxuN78sg4BxHDyM5eBhLAcH4zh4GMuB623PyUWcJAsAACyHgAIAACxnxAYUl8ulRx55RC6XK9atjGiM4+BhLAcPYzk4GMfBw1gOvxF5kiwAABjdRuweFAAAMHoRUAAAgOUQUAAAgOUQUAAAgOWMyIDy/PPP67LLLtO4ceOUm5urQ4cOxbolyzlw4IBuuukmZWRkyGaz6bXXXotYbozRww8/rPT0dCUkJCgvL0+ffPJJRE1DQ4NWrVolt9ut5ORkrV27Vi0tLcO4FbFXWlqqhQsXasKECUpNTdXKlStVXV0dUdPW1qbCwkJNmjRJSUlJKigoUF1dXURNTU2NVqxYocTERKWmpuqBBx5QV1fXcG5KTG3evFk5OTnhm1z5fD7t2rUrvJwx7L8nnnhCNptNGzZsCM9jPPvm0Ucflc1mi5hmzZoVXs44xpgZYbZv326cTqf5l3/5F3P8+HFz9913m+TkZFNXVxfr1izlzTffNH/7t39rXn31VSPJ7NixI2L5E088YTwej3nttdfMf/zHf5jvfOc7Zvr06ebChQvhmmXLlpl58+aZ999/3/z7v/+7mTFjhrnjjjuGeUtiKz8/37z44ovm2LFjprKy0txwww0mKyvLtLS0hGvuuecek5mZafbu3Ws++OADs2jRIvONb3wjvLyrq8vMmTPH5OXlmSNHjpg333zTTJ482ZSUlMRik2Li3/7t38xvfvMb84c//MFUV1ebH//4xyY+Pt4cO3bMGMMY9tehQ4fMZZddZnJycswPfvCD8HzGs28eeeQRc9VVV5na2trwdPbs2fByxjG2RlxAueaaa0xhYWH4dTAYNBkZGaa0tDSGXVnbFwNKKBQyXq/XPPXUU+F5jY2NxuVymV/84hfGGGM+/PBDI8kcPnw4XLNr1y5js9nMn/70p2Hr3Wrq6+uNJFNWVmaM+Xzc4uPjzSuvvBKu+eijj4wkU15eboz5PCza7Xbj9/vDNZs3bzZut9u0t7cP7wZYyMSJE80//dM/MYb91NzcbGbOnGn27Nlj/uIv/iIcUBjPvnvkkUfMvHnzul3GOMbeiDrE09HRoYqKCuXl5YXn2e125eXlqby8PIadjSynTp2S3++PGEePx6Pc3NzwOJaXlys5OVkLFiwI1+Tl5clut+vgwYPD3rNVNDU1SZJSUlIkSRUVFers7IwYy1mzZikrKytiLOfOnau0tLRwTX5+vgKBgI4fPz6M3VtDMBjU9u3b1draKp/Pxxj2U2FhoVasWBExbhK/k9H65JNPlJGRocsvv1yrVq1STU2NJMbRCkbUwwI/++wzBYPBiF8GSUpLS9PHH38co65GHr/fL0ndjuPFZX6/X6mpqRHLHQ6HUlJSwjVjTSgU0oYNG3Tttddqzpw5kj4fJ6fTqeTk5IjaL45ld2N9cdlYUVVVJZ/Pp7a2NiUlJWnHjh3Kzs5WZWUlYxil7du36/e//70OHz58yTJ+J/suNzdXW7du1ZVXXqna2lpt2rRJ1113nY4dO8Y4WsCICihALBUWFurYsWN69913Y93KiHTllVeqsrJSTU1N+vWvf601a9aorKws1m2NOKdPn9YPfvAD7dmzR+PGjYt1OyPa8uXLwz/n5OQoNzdX06ZN069+9SslJCTEsDNII+wqnsmTJysuLu6Ss6jr6urk9Xpj1NXIc3GsehpHr9er+vr6iOVdXV1qaGgYk2NdVFSknTt36p133tHUqVPD871erzo6OtTY2BhR/8Wx7G6sLy4bK5xOp2bMmKH58+ertLRU8+bN07PPPssYRqmiokL19fX6+te/LofDIYfDobKyMj333HNyOBxKS0tjPPspOTlZX/3qV3XixAl+Ly1gRAUUp9Op+fPna+/eveF5oVBIe/fulc/ni2FnI8v06dPl9XojxjEQCOjgwYPhcfT5fGpsbFRFRUW4Zt++fQqFQsrNzR32nmPFGKOioiLt2LFD+/bt0/Tp0yOWz58/X/Hx8RFjWV1drZqamoixrKqqigh8e/bskdvtVnZ29vBsiAWFQiG1t7czhlFasmSJqqqqVFlZGZ4WLFigVatWhX9mPPunpaVFJ0+eVHp6Or+XVhDrs3SjtX37duNyuczWrVvNhx9+aNatW2eSk5MjzqLG52f4HzlyxBw5csRIMv/wD/9gjhw5Yv7rv/7LGPP5ZcbJycnm9ddfN0ePHjU333xzt5cZf+1rXzMHDx407777rpk5c+aYu8x4/fr1xuPxmP3790dcinj+/PlwzT333GOysrLMvn37zAcffGB8Pp/x+Xzh5RcvRVy6dKmprKw0u3fvNlOmTBlTlyI++OCDpqyszJw6dcocPXrUPPjgg8Zms5nf/va3xhjGcKD+/CoeYxjPvrr//vvN/v37zalTp8zvfvc7k5eXZyZPnmzq6+uNMYxjrI24gGKMMT//+c9NVlaWcTqd5pprrjHvv/9+rFuynHfeecdIumRas2aNMebzS41/8pOfmLS0NONyucySJUtMdXV1xDrOnTtn7rjjDpOUlGTcbrf53ve+Z5qbm2OwNbHT3RhKMi+++GK45sKFC+b73/++mThxoklMTDR/9Vd/ZWprayPW88c//tEsX77cJCQkmMmTJ5v777/fdHZ2DvPWxM7f/M3fmGnTphmn02mmTJlilixZEg4nxjCGA/XFgMJ49s1tt91m0tPTjdPpNF/5ylfMbbfdZk6cOBFezjjGls0YY2Kz7wYAAKB7I+ocFAAAMDYQUAAAgOUQUAAAgOUQUAAAgOUQUAAAgOUQUAAAgOUQUAAAgOUQUAAAgOUQUAAAgOUQUAAAgOUQUAAAgOUQUAAAgOX8fzD2Jjs3QpOEAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "200.0"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test(play=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Gym",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
